{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {},
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W3D1_BayesianDecisions/W3D1_Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a> &nbsp; <a href=\"https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/NeoNeuron/professional-workshop-3/master/tutorials/W7_BayesianDecisions/W7_Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" alt=\"Open in Kaggle\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Tutorial 1: Bayes with a binary hidden state\n",
    "\n",
    "__Content creators:__ Eric DeWitt, Xaq Pitkow, Ella Batty, Saeed Salehi\n",
    "\n",
    "__Content modified:__ Kai Chen\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Tutorial Objectives\n",
    "\n",
    "This is the first of two core tutorials on Bayesian statistics. In these tutorials, we will explore the fundemental concepts of the Bayesian approach. In this tutorial you will work through an example of Bayesian inference and decision making using a binary hidden state. The second tutorial extends these concepts to a continuous hidden state. In the next few weeks, each of these basic ideas will be extended. In Hidden Dynamics, we consider these idea through time as you explore what happens when we infere a hidden state using repeated observations and when the hidden state changes across time. In the Optimal Control week, we will introduce the notion of how to use inference and decisions to select actions for optimal control.\n",
    "\n",
    "This notebook will introduce the fundamental building blocks for Bayesian statistics: \n",
    "\n",
    "1. How do we combine the possible loss (or gain) for making a decision with our probabilitic knowledge?\n",
    "2. How do we use probability distributions to represent hidden states?\n",
    "3. How does marginalization work and how can we use it?\n",
    "4. How do we combine new information with our prior knowledge? "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Setup  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import patches, transforms, gridspec\n",
    "from scipy.optimize import fsolve\n",
    "\n",
    "from collections import namedtuple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@title Figure Settings\n",
    "import ipywidgets as widgets       # interactive display\n",
    "from ipywidgets import GridspecLayout, HBox, VBox, FloatSlider, Layout, ToggleButtons\n",
    "from ipywidgets import interactive, interactive_output, Checkbox, Select\n",
    "from IPython.display import clear_output\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "nma_style = {\n",
    "    'figure.figsize' : (8, 6),\n",
    "    'figure.autolayout' : True,\n",
    "    'font.size' : 15,\n",
    "    'xtick.labelsize' : 'small',\n",
    "    'ytick.labelsize' : 'small',\n",
    "    'legend.fontsize' : 'small',\n",
    "    'axes.spines.top' : False,\n",
    "    'axes.spines.right' : False,\n",
    "    'xtick.major.size' : 5,\n",
    "    'ytick.major.size' : 5,\n",
    "}\n",
    "for key, value in nma_style.items():\n",
    "    plt.rcParams[key] = value\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @title Plotting Functions\n",
    "\n",
    "def plot_joint_probs(P, ):\n",
    "    assert np.all(P >= 0), \"probabilities should be >= 0\"\n",
    "    # normalize if not\n",
    "    P = P / np.sum(P)\n",
    "    marginal_y = np.sum(P,axis=1)\n",
    "    marginal_x = np.sum(P,axis=0)\n",
    "\n",
    "    # definitions for the axes\n",
    "    left, width = 0.1, 0.65\n",
    "    bottom, height = 0.1, 0.65\n",
    "    spacing = 0.005\n",
    "\n",
    "    # start with a square Figure\n",
    "    fig = plt.figure(figsize=(5, 5))\n",
    "\n",
    "    joint_prob = [left, bottom, width, height]\n",
    "    rect_histx = [left, bottom + height + spacing, width, 0.2]\n",
    "    rect_histy = [left + width + spacing, bottom, 0.2, height]\n",
    "\n",
    "    rect_x_cmap = plt.cm.Blues\n",
    "    rect_y_cmap = plt.cm.Reds\n",
    "\n",
    "    # Show joint probs and marginals\n",
    "    ax = fig.add_axes(joint_prob)\n",
    "    ax_x = fig.add_axes(rect_histx, sharex=ax)\n",
    "    ax_y = fig.add_axes(rect_histy, sharey=ax)\n",
    "\n",
    "    # Show joint probs and marginals\n",
    "    ax.matshow(P,vmin=0., vmax=1., cmap='Greys')\n",
    "    ax_x.bar(0, marginal_x[0], facecolor=rect_x_cmap(marginal_x[0]))\n",
    "    ax_x.bar(1, marginal_x[1], facecolor=rect_x_cmap(marginal_x[1]))\n",
    "    ax_y.barh(0, marginal_y[0], facecolor=rect_y_cmap(marginal_y[0]))\n",
    "    ax_y.barh(1, marginal_y[1], facecolor=rect_y_cmap(marginal_y[1]))\n",
    "    # set limits\n",
    "    ax_x.set_ylim([0,1])\n",
    "    ax_y.set_xlim([0,1])\n",
    "\n",
    "    # show values\n",
    "    ind = np.arange(2)\n",
    "    x,y = np.meshgrid(ind,ind)\n",
    "    for i,j in zip(x.flatten(), y.flatten()):\n",
    "        c = f\"{P[i,j]:.2f}\"\n",
    "        ax.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = marginal_x[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_x.text(i, v +0.1, c, va='center', ha='center', color='black')\n",
    "        v = marginal_y[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_y.text(v+0.2, i, c, va='center', ha='center', color='black')\n",
    "\n",
    "    # set up labels\n",
    "    ax.xaxis.tick_bottom()\n",
    "    ax.yaxis.tick_left()\n",
    "    ax.set_xticks([0,1])\n",
    "    ax.set_yticks([0,1])\n",
    "    ax.set_xticklabels(['Silver','Gold'])\n",
    "    ax.set_yticklabels(['Small', 'Large'])\n",
    "    ax.set_xlabel('color')\n",
    "    ax.set_ylabel('size')\n",
    "    ax_x.axis('off')\n",
    "    ax_y.axis('off')\n",
    "    return fig\n",
    "\n",
    "\n",
    "def plot_prior_likelihood_posterior(prior, likelihood, posterior):\n",
    "\n",
    "    # definitions for the axes\n",
    "    left, width = 0.05, 0.3\n",
    "    bottom, height = 0.05, 0.9\n",
    "    padding = 0.12\n",
    "    small_width = 0.1\n",
    "    left_space = left + small_width + padding\n",
    "    added_space = padding + width\n",
    "\n",
    "    fig = plt.figure(figsize=(12, 4))\n",
    "\n",
    "    rect_prior = [left, bottom, small_width, height]\n",
    "    rect_likelihood = [left_space , bottom , width, height]\n",
    "    rect_posterior = [left_space +  added_space, bottom , width, height]\n",
    "\n",
    "    ax_prior = fig.add_axes(rect_prior)\n",
    "    ax_likelihood = fig.add_axes(rect_likelihood, sharey=ax_prior)\n",
    "    ax_posterior = fig.add_axes(rect_posterior, sharey = ax_prior)\n",
    "\n",
    "    rect_colormap = plt.cm.Blues\n",
    "\n",
    "    # Show posterior probs and marginals\n",
    "    ax_prior.barh(0, prior[0], facecolor = rect_colormap(prior[0, 0]))\n",
    "    ax_prior.barh(1, prior[1], facecolor = rect_colormap(prior[1, 0]))\n",
    "    ax_likelihood.matshow(likelihood, vmin=0., vmax=1., cmap='Reds')\n",
    "    ax_posterior.matshow(posterior, vmin=0., vmax=1., cmap='Greens')\n",
    "\n",
    "\n",
    "    # Probabilities plot details\n",
    "    # ax_prior.set(xlim = [1, 0], yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "    #              ylabel = 'state (s)', title = \"Prior p(s)\")\n",
    "    ax_prior.set(xlim = [1, 0], xticks = [], yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                 title = \"Prior p(s)\")\n",
    "    ax_prior.yaxis.tick_right()\n",
    "    ax_prior.spines['left'].set_visible(False)\n",
    "    ax_prior.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Likelihood plot details\n",
    "    ax_likelihood.set(xticks = [0, 1], xticklabels = ['fish', 'no fish'],\n",
    "                  yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                   ylabel = 'state (s)', xlabel = 'measurement (m)',\n",
    "                   title = 'Likelihood p(m (left) | s)')\n",
    "    ax_likelihood.xaxis.set_ticks_position('bottom')\n",
    "    ax_likelihood.spines['left'].set_visible(False)\n",
    "    ax_likelihood.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Posterior plot details\n",
    "\n",
    "    ax_posterior.set(xticks = [0, 1], xticklabels = ['fish', 'no fish'],\n",
    "                  yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                   ylabel = 'state (s)', xlabel = 'measurement (m)',\n",
    "                   title = 'Posterior p(s | m)')\n",
    "    ax_posterior.xaxis.set_ticks_position('bottom')\n",
    "    ax_posterior.spines['left'].set_visible(False)\n",
    "    ax_posterior.spines['bottom'].set_visible(False)\n",
    "\n",
    "\n",
    "    # show values\n",
    "    ind = np.arange(2)\n",
    "    x,y = np.meshgrid(ind,ind)\n",
    "    for i,j in zip(x.flatten(), y.flatten()):\n",
    "        c = f\"{posterior[i,j]:.2f}\"\n",
    "        ax_posterior.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i,j in zip(x.flatten(), y.flatten()):\n",
    "        c = f\"{likelihood[i,j]:.2f}\"\n",
    "        ax_likelihood.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = prior[i, 0]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_prior.text(v+0.2, i, c, va='center', ha='center', color='black')\n",
    "\n",
    "\n",
    "def plot_prior_likelihood(ps, p_a_s1, p_a_s0, measurement):\n",
    "    likelihood = np.asarray([[p_a_s1, 1-p_a_s1],[p_a_s0, 1-p_a_s0]])\n",
    "    assert 0.0 <= ps <= 1.0\n",
    "    prior = np.asarray([ps, 1 - ps])\n",
    "    if measurement == \"Fish\":\n",
    "        posterior = likelihood[:, 0] * prior\n",
    "    else:\n",
    "        posterior = (likelihood[:, 1] * prior).reshape(-1)\n",
    "    posterior /= np.sum(posterior)\n",
    "\n",
    "    # definitions for the axes\n",
    "    left, width = 0.05, 0.3\n",
    "    bottom, height = 0.05, 0.9\n",
    "    padding = 0.12\n",
    "    small_width = 0.2\n",
    "    left_space = left + small_width + padding\n",
    "    small_padding = 0.05\n",
    "\n",
    "    fig = plt.figure(figsize=(12, 4))\n",
    "\n",
    "    rect_prior = [left, bottom, small_width, height]\n",
    "    rect_likelihood = [left_space , bottom , width, height]\n",
    "    rect_posterior = [left_space + width + small_padding, bottom , small_width, height]\n",
    "\n",
    "    ax_prior = fig.add_axes(rect_prior)\n",
    "    ax_likelihood = fig.add_axes(rect_likelihood, sharey=ax_prior)\n",
    "    ax_posterior = fig.add_axes(rect_posterior, sharey=ax_prior)\n",
    "\n",
    "    prior_colormap = plt.cm.Blues\n",
    "    posterior_colormap = plt.cm.Greens\n",
    "\n",
    "    # Show posterior probs and marginals\n",
    "    ax_prior.barh(0, prior[0], facecolor = prior_colormap(prior[0]))\n",
    "    ax_prior.barh(1, prior[1], facecolor = prior_colormap(prior[1]))\n",
    "    ax_likelihood.matshow(likelihood, vmin=0., vmax=1., cmap='Reds')\n",
    "    # ax_posterior.matshow(posterior, vmin=0., vmax=1., cmap='')\n",
    "    ax_posterior.barh(0, posterior[0], facecolor = posterior_colormap(posterior[0]))\n",
    "    ax_posterior.barh(1, posterior[1], facecolor = posterior_colormap(posterior[1]))\n",
    "\n",
    "    # Probabilities plot details\n",
    "    ax_prior.set(xlim = [1, 0], yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                 title = \"Prior p(s)\", xticks = [])\n",
    "    ax_prior.yaxis.tick_right()\n",
    "    ax_prior.spines['left'].set_visible(False)\n",
    "    ax_prior.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Likelihood plot details\n",
    "    ax_likelihood.set(xticks = [0, 1], xticklabels = ['fish', 'no fish'],\n",
    "                  yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                   ylabel = 'state (s)', xlabel = 'measurement (m)',\n",
    "                   title = 'Likelihood p(m | s)')\n",
    "    ax_likelihood.xaxis.set_ticks_position('bottom')\n",
    "    ax_likelihood.spines['left'].set_visible(False)\n",
    "    ax_likelihood.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Posterior plot details\n",
    "    ax_posterior.set(xlim = [0, 1], xticks = [], yticks = [0, 1],\n",
    "                     yticklabels = ['left', 'right'], title = \"Posterior p(s | m)\")\n",
    "    ax_posterior.spines['left'].set_visible(False)\n",
    "    ax_posterior.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # show values\n",
    "    ind = np.arange(2)\n",
    "    x,y = np.meshgrid(ind,ind)\n",
    "    # for i,j in zip(x.flatten(), y.flatten()):\n",
    "    #     c = f\"{posterior[i,j]:.2f}\"\n",
    "    #     ax_posterior.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = posterior[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_posterior.text(v+0.2, i, c, va='center', ha='center', color='black')\n",
    "    for i,j in zip(x.flatten(), y.flatten()):\n",
    "        c = f\"{likelihood[i,j]:.2f}\"\n",
    "        ax_likelihood.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = prior[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_prior.text(v+0.2, i, c, va='center', ha='center', color='black')\n",
    "    return fig\n",
    "\n",
    "\n",
    "from matplotlib import colors\n",
    "def plot_utility(ps):\n",
    "    prior = np.asarray([ps, 1 - ps])\n",
    "\n",
    "    utility = np.array([[2, -3], [-2, 1]])\n",
    "\n",
    "    expected = prior @ utility\n",
    "\n",
    "    # definitions for the axes\n",
    "    left, width = 0.05, 0.16\n",
    "    bottom, height = 0.05, 0.9\n",
    "    padding = 0.02\n",
    "    small_width = 0.1\n",
    "    left_space = left + small_width + padding\n",
    "    added_space = padding + width\n",
    "\n",
    "    fig = plt.figure(figsize=(17, 3))\n",
    "\n",
    "    rect_prior = [left, bottom, small_width, height]\n",
    "    rect_utility = [left + added_space , bottom , width, height]\n",
    "    rect_expected = [left + 2* added_space, bottom , width, height]\n",
    "\n",
    "    ax_prior = fig.add_axes(rect_prior)\n",
    "    ax_utility = fig.add_axes(rect_utility, sharey=ax_prior)\n",
    "    ax_expected = fig.add_axes(rect_expected)\n",
    "\n",
    "    rect_colormap = plt.cm.Blues\n",
    "\n",
    "    # Data of plots\n",
    "    ax_prior.barh(0, prior[0], facecolor = rect_colormap(prior[0]))\n",
    "    ax_prior.barh(1, prior[1], facecolor = rect_colormap(prior[1]))\n",
    "    ax_utility.matshow(utility, cmap='cool')\n",
    "    norm = colors.Normalize(vmin=-3, vmax=3)\n",
    "    ax_expected.bar(0, expected[0], facecolor = rect_colormap(norm(expected[0])))\n",
    "    ax_expected.bar(1, expected[1], facecolor = rect_colormap(norm(expected[1])))\n",
    "\n",
    "    # Probabilities plot details\n",
    "    ax_prior.set(xlim = [1, 0], xticks = [], yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                 title = \"Probability of state\")\n",
    "    ax_prior.yaxis.tick_right()\n",
    "    ax_prior.spines['left'].set_visible(False)\n",
    "    ax_prior.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Utility plot details\n",
    "    ax_utility.set(xticks = [0, 1], xticklabels = ['left', 'right'],\n",
    "                  yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                   ylabel = 'state (s)', xlabel = 'action (a)',\n",
    "                   title = 'Utility')\n",
    "    ax_utility.xaxis.set_ticks_position('bottom')\n",
    "    ax_utility.spines['left'].set_visible(False)\n",
    "    ax_utility.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Expected utility plot details\n",
    "    ax_expected.set(title = 'Expected utility', ylim = [-3, 3],\n",
    "                    xticks = [0, 1], xticklabels = ['left', 'right'],\n",
    "                    xlabel = 'action (a)',\n",
    "                    yticks = [])\n",
    "    ax_expected.xaxis.set_ticks_position('bottom')\n",
    "    ax_expected.spines['left'].set_visible(False)\n",
    "    ax_expected.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # show values\n",
    "    ind = np.arange(2)\n",
    "    x,y = np.meshgrid(ind,ind)\n",
    "\n",
    "    for i,j in zip(x.flatten(), y.flatten()):\n",
    "        c = f\"{utility[i,j]:.2f}\"\n",
    "        ax_utility.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = prior[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_prior.text(v+0.2, i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = expected[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_expected.text(i, 2.5, c, va='center', ha='center', color='black')\n",
    "\n",
    "    return fig\n",
    "\n",
    "\n",
    "def plot_prior_likelihood_utility(ps, p_a_s1, p_a_s0, measurement):\n",
    "    assert 0.0 <= ps <= 1.0\n",
    "    assert 0.0 <= p_a_s1 <= 1.0\n",
    "    assert 0.0 <= p_a_s0 <= 1.0\n",
    "    prior = np.asarray([ps, 1 - ps])\n",
    "    likelihood = np.asarray([[p_a_s1, 1-p_a_s1],[p_a_s0, 1-p_a_s0]])\n",
    "    utility = np.array([[2.0, -3.0], [-2.0, 1.0]])\n",
    "    # expected = np.zeros_like(utility)\n",
    "\n",
    "    if measurement == \"Fish\":\n",
    "        posterior = likelihood[:, 0] * prior\n",
    "    else:\n",
    "        posterior = (likelihood[:, 1] * prior).reshape(-1)\n",
    "    posterior /= np.sum(posterior)\n",
    "    # expected[:, 0] = utility[:, 0] * posterior\n",
    "    # expected[:, 1] = utility[:, 1] * posterior\n",
    "    expected = posterior @ utility\n",
    "\n",
    "    # definitions for the axes\n",
    "    left, width = 0.05, 0.3\n",
    "    bottom, height = 0.05, 0.3\n",
    "    padding = 0.12\n",
    "    small_width = 0.2\n",
    "    left_space = left + small_width + padding\n",
    "    small_padding = 0.05\n",
    "\n",
    "    fig = plt.figure(figsize=(10, 9))\n",
    "\n",
    "    rect_prior = [left, bottom + height + padding, small_width, height]\n",
    "    rect_likelihood = [left_space , bottom + height + padding , width, height]\n",
    "    rect_posterior = [left_space + width + small_padding, bottom + height + padding , small_width, height]\n",
    "\n",
    "    rect_utility = [padding, bottom, width, height]\n",
    "    rect_expected = [padding + width + padding + left, bottom, width, height]\n",
    "\n",
    "    ax_likelihood = fig.add_axes(rect_likelihood)\n",
    "    ax_prior = fig.add_axes(rect_prior, sharey=ax_likelihood)\n",
    "    ax_posterior = fig.add_axes(rect_posterior, sharey=ax_likelihood)\n",
    "    ax_utility = fig.add_axes(rect_utility)\n",
    "    ax_expected = fig.add_axes(rect_expected)\n",
    "\n",
    "    prior_colormap = plt.cm.Blues\n",
    "    posterior_colormap = plt.cm.Greens\n",
    "    expected_colormap = plt.cm.Wistia\n",
    "\n",
    "    # Show posterior probs and marginals\n",
    "    ax_prior.barh(0, prior[0], facecolor = prior_colormap(prior[0]))\n",
    "    ax_prior.barh(1, prior[1], facecolor = prior_colormap(prior[1]))\n",
    "    ax_likelihood.matshow(likelihood, vmin=0., vmax=1., cmap='Reds')\n",
    "    ax_posterior.barh(0, posterior[0], facecolor = posterior_colormap(posterior[0]))\n",
    "    ax_posterior.barh(1, posterior[1], facecolor = posterior_colormap(posterior[1]))\n",
    "    ax_utility.matshow(utility, vmin=0., vmax=1., cmap='cool')\n",
    "    # ax_expected.matshow(expected, vmin=0., vmax=1., cmap='Wistia')\n",
    "    ax_expected.bar(0, expected[0], facecolor = expected_colormap(expected[0]))\n",
    "    ax_expected.bar(1, expected[1], facecolor = expected_colormap(expected[1]))\n",
    "\n",
    "    # Probabilities plot details\n",
    "    ax_prior.set(xlim = [1, 0], yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                 title = \"Prior p(s)\", xticks = [])\n",
    "    ax_prior.yaxis.tick_right()\n",
    "    ax_prior.spines['left'].set_visible(False)\n",
    "    ax_prior.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Likelihood plot details\n",
    "    ax_likelihood.set(xticks = [0, 1], xticklabels = ['fish', 'no fish'],\n",
    "                  yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                   ylabel = 'state (s)', xlabel = 'measurement (m)',\n",
    "                   title = 'Likelihood p(m | s)')\n",
    "    ax_likelihood.xaxis.set_ticks_position('bottom')\n",
    "    ax_likelihood.spines['left'].set_visible(False)\n",
    "    ax_likelihood.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Posterior plot details\n",
    "    ax_posterior.set(xlim = [0, 1], xticks = [], yticks = [0, 1],\n",
    "                     yticklabels = ['left', 'right'], title = \"Posterior p(s | m)\")\n",
    "    ax_posterior.spines['left'].set_visible(False)\n",
    "    ax_posterior.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Utility plot details\n",
    "    ax_utility.set(xticks = [0, 1], xticklabels = ['left', 'right'],\n",
    "                   xlabel = 'action (a)', yticks = [0, 1], yticklabels = ['left', 'right'],\n",
    "                   title = 'Utility', ylabel = 'state (s)')\n",
    "    ax_utility.xaxis.set_ticks_position('bottom')\n",
    "    ax_utility.spines['left'].set_visible(False)\n",
    "    ax_utility.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # Expected Utility plot details\n",
    "    ax_expected.set(ylim = [-2, 2], xticks = [0, 1], xticklabels = ['left', 'right'],\n",
    "                 xlabel = 'action (a)', title = 'Expected utility', yticks=[])\n",
    "    # ax_expected.axis('off')\n",
    "    ax_expected.spines['left'].set_visible(False)\n",
    "    # ax_expected.set(xticks = [0, 1], xticklabels = ['left', 'right'],\n",
    "    #                 xlabel = 'action (a)',\n",
    "    #                title = 'Expected utility')\n",
    "    # ax_expected.xaxis.set_ticks_position('bottom')\n",
    "    # ax_expected.spines['left'].set_visible(False)\n",
    "    # ax_expected.spines['bottom'].set_visible(False)\n",
    "\n",
    "    # show values\n",
    "    ind = np.arange(2)\n",
    "    x,y = np.meshgrid(ind,ind)\n",
    "    for i in ind:\n",
    "        v = posterior[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_posterior.text(v+0.2, i, c, va='center', ha='center', color='black')\n",
    "    for i,j in zip(x.flatten(), y.flatten()):\n",
    "        c = f\"{likelihood[i,j]:.2f}\"\n",
    "        ax_likelihood.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i,j in zip(x.flatten(), y.flatten()):\n",
    "        c = f\"{utility[i,j]:.2f}\"\n",
    "        ax_utility.text(j,i, c, va='center', ha='center', color='black')\n",
    "    # for i,j in zip(x.flatten(), y.flatten()):\n",
    "    #     c = f\"{expected[i,j]:.2f}\"\n",
    "    #     ax_expected.text(j,i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = prior[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_prior.text(v+0.2, i, c, va='center', ha='center', color='black')\n",
    "    for i in ind:\n",
    "        v = expected[i]\n",
    "        c = f\"{v:.2f}\"\n",
    "        ax_expected.text(i, v, c, va='center', ha='center', color='black')\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @title Helper Functions\n",
    "\n",
    "def compute_marginal(px, py, cor):\n",
    "    \"\"\" Calculate 2x2 joint probabilities given marginals p(x=1), p(y=1) and correlation\n",
    "\n",
    "      Args:\n",
    "        px (scalar): marginal probability of x\n",
    "        py (scalar): marginal probability of y\n",
    "        cor (scalar): correlation value\n",
    "\n",
    "      Returns:\n",
    "        ndarray of size (2, 2): joint probability array of x and y\n",
    "    \"\"\"\n",
    "\n",
    "    p11 = px*py + cor*np.sqrt(px*py*(1-px)*(1-py))\n",
    "    p01 = px - p11\n",
    "    p10 = py - p11\n",
    "    p00 = 1.0 - p11 - p01 - p10\n",
    "\n",
    "    return np.asarray([[p00, p01], [p10, p11]])\n",
    "\n",
    "\n",
    "def compute_cor_range(px,py):\n",
    "    \"\"\" Calculate the allowed range of correlation values given marginals p(x=1)\n",
    "      and p(y=1)\n",
    "\n",
    "    Args:\n",
    "      px (scalar): marginal probability of x\n",
    "      py (scalar): marginal probability of y\n",
    "\n",
    "    Returns:\n",
    "      scalar, scalar: minimum and maximum possible values of correlation\n",
    "    \"\"\"\n",
    "\n",
    "    def p11(corr):\n",
    "        return px*py + corr*np.sqrt(px*py*(1-px)*(1-py))\n",
    "    def p01(corr):\n",
    "        return px - p11(corr)\n",
    "    def p10(corr):\n",
    "        return py - p11(corr)\n",
    "    def p00(corr):\n",
    "        return 1.0 - p11(corr) - p01(corr) - p10(corr)\n",
    "    Cmax = min(fsolve(p01, 0.0), fsolve(p10, 0.0))\n",
    "    Cmin = max(fsolve(p11, 0.0), fsolve(p00, 0.0))\n",
    "    return Cmin, Cmax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 1: Gone Fishin'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "You were just introduced to the **binary hidden state problem** we are going to explore. You need to decide which side to fish on--the hidden state. We know fish like to school together. On different days the school of fish is either on the left or right side, but we don’t know what the case is today. We define our knowledge about the fish location as a distribution over the random hidden state variable. Using our probabilistic knowledge, also called our **belief** about the hidden state, we will explore how to make the decision about where to fish today, based on what to expect in terms of gains or losses for the decision.\n",
    "The gains and losss are defined by the utility of choosing an action, which is fishing on the left or right. The details of the utilities are described \n",
    "\n",
    "In the next two sections we will consider just the probability of where the fish might be and what you gain or lose by choosing where to fish (leaving Bayesian approaches to the last few sections).\n",
    "\n",
    "Remember, you can either think of your self as a scientist conducting an experiment or as a brain trying to make a decision. The Bayesian approach is the same!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 2: Deciding where to fish \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Utility and expected utility\n",
    "\n",
    "You need to decide where to fish. It may seem obvious - you could just fish on the side where the probability of the fish being is higher! Unfortunately, decisions and actions are always a little more complicated. Deciding to fish may be influenced by more than just the probability of the school of fish being there as we saw by the potential issues of submarines and sunburn. The consequences of the action you take is based on the true (but hidden) state of the world and the action you choose! In our example, fishing on the wrong side, where there aren't many fish, is likely to lead to you spending your afternoon not catching fish and therefore getting a sunburn. The submarine represents a risk to fishing on the right side that is greater than the left side. If you want to know what to expect from taking the action of fishing on one side or the other, you need to calculate the expected utility.\n",
    "\n",
    "You know the (prior) probability that the school of fish is on the left side of the dock today, $P(s = \\textrm{left})$. So, you also know the probability the school is on the right, $P(s = \\textrm{right})$, because these two probabilities must add up to 1.\n",
    "\n",
    "We quantify gains and losses numerically using a **utility** function $U(s,a)$, which describes the consequences of your actions: how much value you gain (or if negative, lose) given the state of the world ($s$) and the action you take ($a$). In our example, our utility can be summarized as:\n",
    "\n",
    "| Utility: U(s,a)   | a = left   | a = right  |\n",
    "| ----------------- |------------|------------|\n",
    "| s = Left          | +2         | -3         |\n",
    "| s = right         | -2         | +1         |\n",
    "\n",
    "To use possible gains and losses to choose an action, we calculate the **expected utility** of that action by weighing these utilities with the probability of that state occuring. This allows us to choose actions by taking probabilities of events into account: we don't care if the outcome of an action-state pair is a loss if the probability of that state is very low. We can formalize this as:\n",
    "\n",
    "$$ \\text{Expected utility of action a} = \\sum_{s}U(s,a)P(s) $$\n",
    "\n",
    "In other words, the expected utility of an action a is the sum over possible states of the utility of that action and state times the probability of that state.\n",
    "\n"
   ]
  },
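  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "The expected utility formula above can be sketched in a few lines of NumPy. This is only an illustrative sketch, not part of the original tutorial code: the names `utility`, `p_state`, and `expected_utility` are chosen here for the example, and the belief $P(s = \\textrm{left}) = 0.7$ is an assumed value. The utility values are copied from the table above, with rows indexing states and columns indexing actions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Illustrative sketch (names and the example belief are assumptions)\n",
    "import numpy as np\n",
    "\n",
    "# Utility table U(s, a): rows are states (left, right), columns are actions (left, right)\n",
    "utility = np.array([[2, -3], [-2, 1]])\n",
    "\n",
    "# Assumed example belief: P(s = left) = 0.7, P(s = right) = 0.3\n",
    "p_state = np.array([0.7, 0.3])\n",
    "\n",
    "# Expected utility of each action: sum over states of U(s, a) * P(s)\n",
    "expected_utility = p_state @ utility\n",
    "\n",
    "for action, value in zip(['left', 'right'], expected_utility):\n",
    "    print(f'Expected utility of fishing on the {action}: {value:.2f}')"
   ]
  },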
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Interactive Demo 2: Exploring the decision\n",
    "\n",
    "Let's start to get a sense of how all this works using the interactive demo below. You can change the probability that the school of fish is on the left side,$p(s = \\textrm{left})$, using the slider. You will see the utility function (a matrix) in the middle and the corresponding expected utility for each action on the right.\n",
    "\n",
    "First, make sure you understand how the expected utility of each action is being computed from the probabilities and the utility values. In the initial state: the probability of the fish being on the left is 0.9 and on the right is 0.1. The expected utility of the action of fishing on the left is then $U(s = \\textrm{left},a = \\textrm{left})p(s = \\textrm{left}) + U(s = \\textrm{right},a = \\textrm{left})p(s = \\textrm{right}) = 2(0.9) + -2(0.1) = 1.6$. Essentially, to get the expected utility of action $a$, you are doing a weighted sum over the relevant column of the utility matrix (corresponding to action $a$) where the weights are the state probabilities.\n",
    "\n",
    "For each of these scenarios, think and discuss first. Then use the demo to try out each and see if your action would have been correct (that is, if the expected value of that action is the highest).\n",
    "\n",
    "\n",
    "1.  You just arrived at the dock for the first time and have no sense of where the fish might be. So you guess that the probability of the school being on the left side is 0.5 (so the probability on the right side is also 0.5). Which side would you choose to fish on given our utility values?\n",
    "2.  You think that the probability of the school being on the left side is very low (0.1) and correspondingly high on the right side (0.9). Which side would you choose to fish on given our utility values?\n",
    "3.  What would you choose if the probability of the school being on the left side is slightly lower than on the right side (0. 4 vs 0.6)?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to use the widget\n",
    "ps_widget = widgets.FloatSlider(0.9, description='p(s = left)', min=0.0, max=1.0, step=0.01)\n",
    "\n",
    "@widgets.interact(\n",
    "    ps = ps_widget,\n",
    ")\n",
    "def make_utility_plot(ps):\n",
    "    fig = plot_utility(ps)\n",
    "    plt.show(fig)\n",
    "    plt.close(fig)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "1)  With equal probabilities, the expected utility is higher on the left side,\n",
    "   since that is the side without submarines, so you would choose to fish there.\n",
    "\n",
    "2)  If the probability that the fish is on the right side is high, you would\n",
    "    choose to fish there. The high probability of fish being on the right far outweights\n",
    "   the slightly higher utilities from fishing on the left (as you are unlikely to gain these)\n",
    "\n",
    "3)  If the probability that the fish is on the right side is just slightly higher\n",
    "    than on the left, you would choose the left side as the expected utility is still\n",
    "    higher on the left. Note that in this situation, you are not simply choosing the\n",
    "    side with the higher probability - the utility really matters here for the decision\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "In this section, you have seen that both the utility of various state and action pairs and our knowledge of the probability of each state affects your decision. Importantly, we want our knowledge of the probability of each state to be as accurate as possible! \n",
    "\n",
    "So how do we know these probabilities? We may have prior knowledge from years of fishing at the same dock, learning that the fish are more likely to be on the left side, for example. Of course, we need to update our knowledge (our belief)! To do this, we need to collect more information, or take some measurements! In the next few sections, we will focus on how we improve our knowledge of the probabilities."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 3: Likelihood of the fish being on either side\n",
    "\n",
    "*Estimated timing to here from start of tutorial: 25 min*\n",
    " \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "\n",
    "First, we'll think about what it means to take a measurement (also often called an observation or just data) and what it tells you about what the hidden state may be. Specifically, we'll be looking at the **likelihood**, which is the probability of your measurement ($m$) given the hidden state ($s$): $P(m | s)$. Remember that in this case, the hidden state is which side of the dock the school of fish is on.\n",
    "We will watch someone fish (for let's say 10 minutes) and our measurement is whether they catch a fish or not. We know something about what catching a fish means for the likelihood of the fish being on one side or the other."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Think! 3: Guessing the location of the fish\n",
    "\n",
    "Let's say we go to different dock to fish. Here, there are different probabilities of catching fish given the state of the world. At this dock, if you fish on the side of the dock where the fish are, you have a 70% chance of catching a fish. If you fish on the wrong side, you will catch a fish with only 20% probability. These are the likelihoods of observing someone catching a fish! That is, you are taking a measurement by seeing if someone else catches a fish!\n",
    "\n",
    "You see a fisherperson is fishing on the left side.\n",
    "\n",
    "1) Please figure out each of the following (might be easiest to do this separately and then compare notes):\n",
    "- probability of catching a fish given that the school of fish is on the left side, $P(m = \\textrm{catch fish} | s = \\textrm{left} )$\n",
    "- probability of not catching a fish given that the school of fish is on the left side, $P(m = \\textrm{no fish} | s = \\textrm{left})$\n",
    "- probability of catching a fish given that the school of fish is on the right side, $P(m = \\textrm{catch  fish} | s = \\textrm{right})$\n",
    "- probability of not catching a fish given that the school of fish is on the right side, $P(m = \\textrm{no fish} | s = \\textrm{right})$\n",
    "\n",
    "2) If the fisherperson catches a fish, which side would you guess the school is on? Why?\n",
    "\n",
    "3) If the fisherperson does not catch a fish, which side would you guess the school is on? Why?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "1) The fisherperson is on the left side so:\n",
    "      - P(m = fish | s = left) = 0.7 because they have a 70% chance of catching\n",
    "        a fish when on the same side as the school\n",
    "      - P(m = no fish | s = left) = 0.3 because the probability of catching a fish\n",
    "        and not catching a fish for a given state must add up to 1 as these\n",
    "        are the only options: 1 - 0.7 = 0.3\n",
    "      - P(m = fish | s = right) = 0.2\n",
    "      - P(m = no fish | s = right) = 0.8\n",
    "\n",
    "2) If the fisherperson catches a fish, you would guess the school of fish is on the\n",
    "    left side. This is because the probability of catching a fish given that the\n",
    "   school is on the left side (0.7) is higher than the probability given that\n",
    "   the school is on the right side (0.2).\n",
    "\n",
    "3) If the fisherperson does not catch a fish, you would guess the school of fish is on the\n",
    "    right side. This is because the probability of not catching a fish given that the\n",
    "   school is on the right side (0.8) is higher than the probability given that\n",
    "   the school is on the right side (0.3).\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "In the prior exercise, you tried to guess where the school of fish was based on the measurement you took (watching someone fish). You did this by choosing the state (side where you think the fish are) that maximized the probability of the measurement. In other words, you estimated the state by maximizing the likelihood (the side with the highest probability of measurement given state $P(m|s$)). This is called maximum likelihood estimation (MLE) and you've encountered it before during this course, in the [pre-reqs statistics day](https://compneuro.neuromatch.io/tutorials/W0D5_Statistics/student/W0D5_Tutorial2.html#section-2-2-maximum-likelihood) and on [Model Fitting day](https://compneuro.neuromatch.io/tutorials/W1D3_ModelFitting/student/W1D3_Tutorial2.html)!\n",
    "\n",
    "But, what if you had been going to this dock for years and you knew that the fish were almost always on the left side? This should probably affect how you make your estimate -- you would rely less on the single new measurement and more on your prior knowledge. This is the fundemental idea behind Bayesian inference, as we will see later in this tutorial!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 4: Correlation and marginalization\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Section 4.1: Correlation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "In this section, we are going to take a step back for a bit and think more generally about the amount of information shared between two random variables. We want to know how much information you gain when you observe one variable (take a measurement) if you know something about another. We will see that the fundamental concept is the same if we think about two attributes, for example the size and color of the fish, or the prior information and the likelihood.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Think! 4.1: Covarying probability distributions\n",
    "\n",
    "The relationship between the marginal probabilities and the joint probabilities is determined by the correlation between the two random variables - a normalized measure of how much the variables covary. We can also think of this as gaining some information about one of the variables when we observe a measurement from the other. We will think about this more formally in Tutorial 2. \n",
    "\n",
    "Here, we want to think about how the correlation between size and color of these fish changes how much information we gain about one attribute based on the other. See Bonus Section 1 for the formula for correlation.\n",
    "\n",
    "Use the widget below and answer the following questions:\n",
    "\n",
    "1. When the correlation is zero, $\\rho = 0$, what does the distribution of size tell you about color?\n",
    "2. Set $\\rho$ to something small. As you change the probability of golden fish, what happens to the ratio of size probabilities? Set $\\rho$ larger (can be negative). Can you explain the pattern of changes in the probabilities of size as you change the probability of golden fish?\n",
    "3. Set the probability of golden fish and of large fish to around 65%. As the correlation goes towards 1, how often will you see silver large fish?\n",
    "4. What is increasing the (absolute) correlation telling you about how likely you are to see one of the properties if you see a fish with the other?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to enable the widget\n",
    "style = {'description_width': 'initial'}\n",
    "gs = GridspecLayout(2,2)\n",
    "\n",
    "cor_widget = widgets.FloatSlider(0.0, description='ρ', min=-1, max=1, step=0.01)\n",
    "px_widget = widgets.FloatSlider(0.5, description='p(color=golden)', min=0.01, max=0.99, step=0.01, style=style)\n",
    "py_widget = widgets.FloatSlider(0.5, description='p(size=large)', min=0.01, max=0.99, step=0.01, style=style)\n",
    "gs[0,0] = cor_widget\n",
    "gs[0,1] = px_widget\n",
    "gs[1,0] = py_widget\n",
    "\n",
    "\n",
    "@widgets.interact(\n",
    "    px=px_widget,\n",
    "    py=py_widget,\n",
    "    cor=cor_widget,\n",
    ")\n",
    "def make_corr_plot(px, py, cor):\n",
    "    Cmin, Cmax = compute_cor_range(px, py) #allow correlation values\n",
    "    cor_widget.min, cor_widget.max = Cmin+0.01, Cmax-0.01\n",
    "    if cor_widget.value > Cmax:\n",
    "        cor_widget.value = Cmax\n",
    "    if cor_widget.value < Cmin:\n",
    "        cor_widget.value = Cmin\n",
    "    cor = cor_widget.value\n",
    "    P = compute_marginal(px,py,cor)\n",
    "    # print(P)\n",
    "    fig = plot_joint_probs(P)\n",
    "    plt.show(fig)\n",
    "    plt.close(fig)\n",
    "    return None\n",
    "\n",
    "# gs[1,1] = make_corr_plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "1. When the correlation is zero, the two properties are completely independent.\n",
    "   This means you don't gain any information about one variable from observing the other.\n",
    "   Importantly, the marginal distribution of one variable is therefore independent of the other.\n",
    "\n",
    "2. The correlation controls the distribution of probability across the joint probabilty table.\n",
    "   The higher the correlation, the more the probabilities are restricted by the fact that both rows\n",
    "   and columns need to sum to one! While the marginal probabilities show the relative weighting, the\n",
    "   absolute probabilities for one quality will become more dependent on the other as the correlation\n",
    "   goes to 1 or -1.\n",
    "\n",
    "3. The correlation will control how much probabilty mass is located on the diagonals. As the\n",
    "   correlation goes to 1 (or -1), the probability of seeing the one of the two pairings has to goes\n",
    "   towards zero!\n",
    "\n",
    "4. If we think about what information we gain by observing one quality, the intution from (3.) tells\n",
    "   us that we know more (have more information) about the other quality as a function of the correlation.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "We have just seen how two random variables can be more or less independent. The more correlated, the less independent, and the more shared information. We also learned that we can marginalize to determine the marginal likelihood of a measurement or to find the marginal probability distribution of two random variables. We are going to now complete our journey towards being fully Bayesian!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Section 4.2: Marginalisation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "\n",
    "We may want to find the probability of one variable while ignoring another: we will do this by averaging out, or marginalizing, the irrelevant variable.\n",
    "\n",
    "We will think of this in two different ways.\n",
    "\n",
    "In the first math exercise, you will think about the case where you know the joint probabilities of two variables and want to figure out the probability of just one variable. To make this explicit, let's assume that a fish has a color that is either gold or silver (our first variable) and a size that is either small or large (our second). We could write out the the **joint probabilities**: the probability of both specific attributes occuring together. For example, the probability of a fish being small and silver, $P(X = \\textrm{small}, Y = \\textrm{silver})$, is 0.4. The following table summarizes our joint probabilities:\n",
    "\n",
    "| P(X, Y)        | Y = silver  | Y = gold  |\n",
    "| -------------- |-------------|-----------|\n",
    "| X = small      | 0.4         | 0.2       |\n",
    "| X = large      | 0.1         | 0.3       |\n",
    "\n",
    "\n",
    "We want to know what the probability of a fish being small  regardless of color. Since the fish are either silver or gold, this would be the probability of a fish being small and silver plus the probability of a fish being small and gold. This is an example of marginalizing, or averaging out, the variable we are not interested in across the rows or columns.. In math speak: $P(X = \\textrm{small}) = \\sum_y{P(X = \\textrm{small}, Y)}$. This gives us a **marginal probability**, a probability of a variable outcome (in this case size), regardless of the other variables (in this case color).\n",
    "\n",
    "More generally, we can marginalize out a second irrelevant variable $y$ by summing over the relevant joint probabilities:\n",
    "\n",
    "$$p(x) = \\sum_y p(x, y) $$\n",
    "\n",
    "\n",
    "In the second math exercise, you will remove an unknown (the hidden state) to find the marginal probability of a measurement. You will do this by marginalizing out the hidden state. In this case, you know the conditional probabilities of the measurement given state and the probabilities of each state. You can marginalize using:\n",
    "\n",
    "$$p(m) = \\sum_s p(m | s) p(s) $$\n",
    "\n",
    "These two ways of thinking about marginalization (as averaging over joint probabilities or conditioning on some variable) are equivalent because the joint probability of two variables equals the conditional probability of the first given the second times the marginal probability of the second:\n",
    "\n",
    "$$p(x, y) = p(x|y)p(y)$$ \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Math Exercise 4.2.1: Computing marginal probabilities\n",
    "\n",
    "To understand the information between two variables, let's first consider the joint probabilities over the size and color of the fish.\n",
    "\n",
    "| P(X, Y)        | Y = silver  | Y = gold  |\n",
    "| -------------- |-------------|-----------|\n",
    "| X = small      | 0.4         | 0.2       |\n",
    "| X = large      | 0.1         | 0.3       |\n",
    "\n",
    "Please complete the following math problems to further practice thinking through probabilities:\n",
    "\n",
    "1. Calculate the probability of a fish being silver.\n",
    "2. Calculate the probability of a fish being small, large, silver, or gold.\n",
    "3. Calculate the probability of a fish being small OR gold. (Hint: $P(A\\ \\textrm{or}\\ B) = P(A) + P(B) - P(A\\ \\textrm{and}\\ B)$)\n",
    "\n",
    "**We don't typically have math exercises in NMA but feel it is important for you to really compute this out yourself. Feel free to use the next cell to write out the math if helpful, or use paper and pen. We recommend doing this exercise individually first and then comparing notes and discussing.**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "\n",
    "Joint probabilities\n",
    "\n",
    "P( X = small, Y = silver) = 0.4\n",
    "P( X = large, Y = silver) = 0.1\n",
    "P( X = small, Y = gold) = 0.2\n",
    "P( X = large, Y = gold) = 0.3\n",
    "\n",
    "\n",
    "1. P(Y = silver) = ...\n",
    "\n",
    "2. P(X = small or large, Y = silver or gold) = ...\n",
    "\n",
    "3. P( X = small or Y = gold) = ...\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "1. The probability of a fish being silver is the joint probability of it being\n",
    "     small and silver plus the joint probability of it being large and silver:\n",
    "\n",
    "    P(Y = silver) = P(X = small, Y = silver) + P(X = large, Y = silver)\n",
    "                  = 0.4 + 0.1\n",
    "                  = 0.5\n",
    "\n",
    "\n",
    "2. This is all the possibilities as in this scenario, our fish can only be small\n",
    "   or large, silver or gold. So the probability is 1 - the fish has to be at\n",
    "   least one of these.\n",
    "\n",
    "\n",
    "3. First we compute the marginal probabilities\n",
    "   P(X = small) = P(X = small, Y = silver) + P(X = small, Y = gold) = 0.6\n",
    "   P(Y = gold) = P(X = small, Y = gold) + P(X = large, Y = gold) = 0.5\n",
    "\n",
    "   We already know the joint probability: P(X = small, Y = gold) = 0.2\n",
    "\n",
    "   We can now use the given formula:\n",
    "   P( X = small or Y = gold) = P(X = small) + P(Y = gold) - P(X = small, Y = gold)\n",
    "                             = 0.6 + 0.5 - 0.2\n",
    "                             = 0.9\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Math Exercise 4.2.2: Computing marginal likelihood\n",
    "\n",
    "When we normalize to find the posterior, we need to determine the marinal likelihood--or evidence--for the measurement we observed. To do this, we need to marginalize as we just did above to find the probabilities of a color or size. Only, in this case, we are marginalizing to remove a conditioning variable! In this case, let's consider the likelihood of fish (if we observed a fisherperson fishing on the **right**).\n",
    "\n",
    "| p(m\\|s)       | m = fish | m = no fish  |\n",
    "| ------------ | ---------- | -------------- |\n",
    "| s = left     | 0.1      | 0.9          |\n",
    "| s = right    | 0.5      | 0.5          |\n",
    "\n",
    "\n",
    "The table above shows us the **likelihoods**, just as we explored earlier.\n",
    "\n",
    "You want to know the total probability of a fish being caught, $P(m = \\textrm{fish})$, by the fisherperson fishing on the right. (You would need this to calcualate the posterior.) To do this, you will need to consider the prior probability, $p(s)$, and marginalize over the hidden states!\n",
    "\n",
    "This is an example of marginalizing, or conditioning away, the variable we are not interested in as well.\n",
    "\n",
    "Please complete the following math problems to further practice thinking through probabilities:\n",
    "\n",
    "1. Calculate the marginal likelihood of the fish being caught, $P(m = \\textrm{fish})$, if the priors are: $p(s = \\textrm{left}) = 0.3$ and $p(s = \\textrm{right}) = 0.7$.\n",
    "2. Calculate the marginal likelihood of the fish being caught,  $P(m = \\textrm{fish})$, if the priors are: $p(s = \\textrm{left}) = 0.6$ and $p(s = \\textrm{right}) = 0.4$.\n",
    "\n",
    "**We don't typically have math exercises in NMA but feel it is important for you to really compute this out yourself. Feel free to use the next cell to write out the math if helpful, or use paper and pen. We recommend doing this exercise individually first and then comparing notes and discussing.**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "\n",
    "Priors\n",
    "P(s = left) = 0.3\n",
    "P(s = right) = 0.7\n",
    "\n",
    "Likelihoods\n",
    "P(m = fish | s = left) = 0.1\n",
    "P(m = fish | s = right) = 0.5\n",
    "P(m = no fish | s = left) = 0.9\n",
    "P(m = no fish | s = right) = 0.5\n",
    "\n",
    "1. P(m = fish) = ...\n",
    "\n",
    "2. P(m = fish) = ...\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "1. The marginal likelihood (evidence) is\n",
    "\n",
    "    P(m = fish) = P(m = fish, s = left) + P(m = fish, s = right)\n",
    "                = P(m = fish, s = left)P(s = left) + P(m = fish, s = right)P(s = right)\n",
    "                = 0.1 * 0.3 + .5 * .7\n",
    "            = 0.38\n",
    "\n",
    "\n",
    "2. The marginal likelihood (evidence) is\n",
    "\n",
    "    P(m = fish) = P(m = fish, s = left) + P(m = fish, s = right)\n",
    "                = P(m = fish, s = left)P(s = left) + P(m = fish, s = right)P(s = right)\n",
    "                = 0.1 * 0.6 + .5 * .4\n",
    "                = 0.26\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 5: Bayes' Rule and the Posterior\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "Marginalization is going to be used to combine our prior knowlege, which we call the **prior**, and our new information from a measurement, the **likelihood**. Only in this case, the information we gain about the hidden state we are interested in, where the fish are, is based on the relationship between the probabilities of the measurement and our prior. \n",
    "\n",
    "We can now calculate the full posterior distribution for the hidden state ($s$) using Bayes' Rule. As we've seen, the posterior is proportional the the prior times the likelihood. This means that the posterior probability of the hidden state ($s$) given a measurement ($m$) is proportional to the likelihood of the measurement given the state times the prior probability of that state:\n",
    "\n",
    "$$ P(s | m) \\propto P(m | s) P(s)  $$\n",
    "\n",
    "We say proportional to instead of equal because we need to normalize to produce a full probability distribution:\n",
    "\n",
    "$$ P(s | m) = \\frac{P(m | s) P(s)}{P(m)}  $$\n",
    "\n",
    "Normalizing by this $P(m)$ means that our posterior is a complete probability distribution that sums or integrates to 1 appropriately. We now can use this new, complete probability distribution for any future inference or decisions we like! In fact, as we will see tomorrow, we can use it as a new prior! Finally, we often call this probability distribution our beliefs over the hidden states, to emphasize that it is our subjective knowlege about the hidden state.\n",
    "\n",
    "For many complicated cases, like those we might be using to model behavioral or brain inferences, the normalization term can be intractable or extremely complex to calculate. We can be careful to choose probability distributions were we can analytically calculate the posterior probability or numerical approximation is reliable. Better yet, we sometimes don't need to bother with this normalization! The normalization term, $P(m)$, is the probability of the measurement. This does not depend on state so is essentially a constant we can often ignore. We can compare the unnormalized posterior distribution values for different states because how they relate to each other is unchanged when divided by the same constant. We will see how to do this to compare evidence for different hypotheses tomorrow. (It's also used to compare the likelihood of models fit using maximum likelihood estimation)\n",
    "\n",
    "In this relatively simple example, we can compute the marginal likelihood $P(m)$ easily by using:\n",
    "$$P(m) = \\sum_s P(m | s) P(s)$$\n",
    "We can then normalize so that we deal with the full posterior distribution.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Math Exercise 5: Calculating a posterior probability\n",
    "\n",
    "Our prior is $p(s = \\textrm{left}) = 0.3$ and $p(s = \\textrm{right}) = 0.7$. In the video, we learned that the chance of catching a fish given they fish on the same side as the school was 50%. Otherwise, it was 10%. We observe a person fishing on the left side. Our likelihood is: \n",
    "\n",
    "\n",
    "| Likelihood: p(m \\| s) | m = fish   | m = no fish  |\n",
    "| ----------------- |----------|----------|\n",
    "| s = left          | 0.5          | 0.5         |\n",
    "| s = right         | 0.1        |  0.9       |\n",
    "\n",
    "\n",
    "Calculate the posterior probability (on paper) that:\n",
    "\n",
    "1. The school is on the left if the fisherperson catches a fish: $p(s = \\textrm{left} | m = \\textrm{fish})$ (hint: normalize by computing $p(m = \\textrm{fish})$)\n",
    "2. The school is on the right if the fisherperson does not catch a fish: $p(s = \\textrm{right} | m = \\textrm{no fish})$\n",
    "\n",
    "**We don't typically have math exercises in NMA but feel it is important for you to really compute this out yourself. Feel free to use the next cell to write out the math if helpful, or use paper and pen. We recommend doing this exercise individually first and then comparing notes and discussing.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Priors\n",
    "p(s = left) = 0.3\n",
    "p(s = right) = 0.7\n",
    "\n",
    "Likelihoods\n",
    "P(m = fish | s = left) = 0.5\n",
    "P(m = fish | s = right) = 0.1\n",
    "P(m = no fish | s = left) = 0.5\n",
    "P(m = no fish | s = right) = 0.9\n",
    "\n",
    "1. p( s = left | m = fish) = ...\n",
    "\n",
    "2. p( s = right | m = no fish ) = ...\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "1. Using Bayes rule, we know that P(s = left | m = fish) = P(m = fish | s = left)P(s = left) / P(m = fish)\n",
    "\n",
    "   Let's first compute P(m = fish):\n",
    "   P(m = fish) =  P(m = fish | s = left)P(s = left) +  P(m = fish | s = right)P(s = right)\n",
    "               = 0.5 * 0.3 + .1 * .7\n",
    "               = 0.22\n",
    "\n",
    "   Now we can plug in all parts of Bayes rule:\n",
    "   P(s = left | m = fish) = P(m = fish | s = left)P(s = left) / P(m = fish)\n",
    "                          = 0.5 * 0.3 / 0.22\n",
    "                          = 0.68\n",
    "\n",
    "2. Using Bayes rule, we know that P(s = right | m = no fish) = P(m = no fish | s = right)P(s = right) / P(m = no fish)\n",
    "\n",
    "   Let's first compute P(m = no fish):\n",
    "   P(m = no fish) = P(m = no fish | s = left)P(s = left) +  P(m = no fish | s = right)P(s = right)\n",
    "                  = 0.5 * 0.3 + .9 * .7\n",
    "                  = 0.78\n",
    "\n",
    "   Now we can plug in all parts of Bayes rule:\n",
    "   P(s = right | m = no fish) = P(m = no fish | s = right)P(s = right) / P(m = no fish)\n",
    "                              = 0.9 * 0.7 / 0.78\n",
    "                              = 0.81\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Coding Exercise 5: Computing Posteriors\n",
    "\n",
    "Let's implement our above math to be able to compute posteriors for different priors and likelihoods.\n",
    "\n",
    "As before, our prior is $p(s = \\textrm{left}) = 0.3$ and $p(s = \\textrm{right}) = 0.7$. In the video, we learned that the chance of catching a fish given they fish on the same side as the school was 50%. Otherwise, it was 10%. We observe a person fishing on the left side. Our likelihood is: \n",
    "\n",
    "\n",
    "| Likelihood: p(m \\| s) | m = fish   | m = no fish  |\n",
    "| ----------------- |----------|----------|\n",
    "| s = left          | 0.5          | 0.5         |\n",
    "| s = right         | 0.1        |  0.9       |\n",
    "\n",
    "\n",
    "We want our full posterior to take the same 2 by 2 form. Make sure the outputs match your math answers!\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def compute_posterior(likelihood, prior):\n",
    "  \"\"\" Use Bayes' Rule to compute posterior from likelihood and prior\n",
    "\n",
    "  Args:\n",
    "    likelihood (ndarray): i x j array with likelihood probabilities where i is\n",
    "                    number of state options, j is number of measurement options\n",
    "    prior (ndarray): i x 1 array with prior probability of each state\n",
    "\n",
    "  Returns:\n",
    "    ndarray: i x j array with posterior probabilities where i is\n",
    "            number of state options, j is number of measurement options\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  #################################################\n",
    "  ## TODO for students ##\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: implement compute_posterior\")\n",
    "  #################################################\n",
    "\n",
    "  # Compute unnormalized posterior (likelihood times prior)\n",
    "  posterior = ... # first row is s = left, second row is s = right\n",
    "\n",
    "  # Compute p(m)\n",
    "  p_m = np.sum(posterior, axis = 0)\n",
    "\n",
    "  # Normalize posterior (divide elements by p_m)\n",
    "  posterior /= ...\n",
    "\n",
    "  return posterior\n",
    "\n",
    "# Make prior\n",
    "prior = np.array([0.3, 0.7]).reshape((2, 1)) # first row is s = left, second row is s = right\n",
    "\n",
    "# Make likelihood\n",
    "likelihood = np.array([[0.5, 0.5], [0.1, 0.9]]) # first row is s = left, second row is s = right\n",
    "\n",
    "# Compute posterior\n",
    "posterior = compute_posterior(likelihood, prior)\n",
    "\n",
    "# Visualize\n",
    "plot_prior_likelihood_posterior(prior, likelihood, posterior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "\n",
    "def compute_posterior(likelihood, prior):\n",
    "  \"\"\" Use Bayes' Rule to compute posterior from likelihood and prior\n",
    "\n",
    "  Args:\n",
    "    likelihood (ndarray): i x j array with likelihood probabilities where i is\n",
    "                    number of state options, j is number of measurement options\n",
    "    prior (ndarray): i x 1 array with prior probability of each state\n",
    "\n",
    "  Returns:\n",
    "    ndarray: i x j array with posterior probabilities where i is\n",
    "            number of state options, j is number of measurement options\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Compute unnormalized posterior (likelihood times prior)\n",
    "  posterior = likelihood * prior # first row is s = left, second row is s = right\n",
    "\n",
    "  # Compute p(m)\n",
    "  p_m = np.sum(posterior, axis = 0)\n",
    "\n",
    "  # Normalize posterior (divide elements by p_m)\n",
    "  posterior /= p_m\n",
    "\n",
    "  return posterior\n",
    "\n",
    "\n",
    "# Make prior\n",
    "prior = np.array([0.3, 0.7]).reshape((2, 1)) # first row is s = left, second row is s = right\n",
    "\n",
    "# Make likelihood\n",
    "likelihood = np.array([[0.5, 0.5], [0.1, 0.9]]) # first row is s = left, second row is s = right\n",
    "\n",
    "# Compute posterior\n",
    "posterior = compute_posterior(likelihood, prior)\n",
    "\n",
    "# Visualize\n",
    "plot_prior_likelihood_posterior(prior, likelihood, posterior)"
   ]
  },
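  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "As a quick sanity check on the exercise above, the sketch below recomputes the posterior step by step and confirms that each column sums to 1. The `*_check` variable names are our own and are chosen so that they do not overwrite the variables used elsewhere in the notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Same prior and likelihood as in the exercise (rows: s = left, s = right)\n",
    "prior_check = np.array([0.3, 0.7]).reshape((2, 1))\n",
    "likelihood_check = np.array([[0.5, 0.5], [0.1, 0.9]])\n",
    "\n",
    "# Joint p(m, s): broadcasting scales each likelihood row by its prior\n",
    "joint_check = likelihood_check * prior_check\n",
    "\n",
    "# Marginal p(m): sum the joint over states (rows)\n",
    "p_m_check = joint_check.sum(axis=0)\n",
    "\n",
    "# Posterior p(s | m): each column is divided by its own p(m)\n",
    "posterior_check = joint_check / p_m_check\n",
    "\n",
    "# Each column is a distribution over states, so it must sum to 1\n",
    "assert np.allclose(posterior_check.sum(axis=0), 1.0)\n",
    "print(np.round(posterior_check, 2))\n",
    "```\n",
    "\n",
    "The bottom-right entry, p(s = right | m = no fish), should come out near 0.81, matching the hand calculation."
   ]
  },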
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Interactive Demo 5: What affects the posterior?\n",
    "\n",
    "Now that we can understand the implementation of *Bayes rule*, let's vary the parameters of the prior and likelihood to see how changing the prior and likelihood affect the posterior. \n",
    "\n",
    "In the demo below, you can change the prior by playing with the slider for $p( s = left)$. You can also change the likelihood by changing the probability of catching a fish given that the school is on the left and the probability of catching a fish given that the school is on the right. The fisherperson you are observing is fishing on the left.\n",
    " \n",
    "\n",
    "1.   Keeping the likelihood constant, when does the prior have the strongest influence over the posterior? Meaning, when does the posterior look most like the prior no matter whether a fish was caught or not?\n",
    "2.   What happens if the likelihoods for catching a fish are similar when you fish on the correct or incorrect side?\n",
    "3.  Set the prior probability of the state = left to 0.6 and play with the likelihood. When does the likelihood exert the most influence over the posterior?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to enable the widget\n",
    "# style = {'description_width': 'initial'}\n",
    "ps_widget = widgets.FloatSlider(0.3, description='p(s = left)',\n",
    "                                min=0.01, max=0.99, step=0.01)\n",
    "p_a_s1_widget = widgets.FloatSlider(0.5, description='p(fish on left | state = left)',\n",
    "                                    min=0.01, max=0.99, step=0.01, style=style, layout=Layout(width='370px'))\n",
    "p_a_s0_widget = widgets.FloatSlider(0.1, description='p(fish on left | state = right)',\n",
    "                                    min=0.01, max=0.99, step=0.01, style=style, layout=Layout(width='370px'))\n",
    "# observed_widget = widgets.Checkbox(value=False, description='Observed fish (m)',\n",
    "#                                  disabled=False, indent=False,\n",
    "#                                  layout=Layout(display=\"flex\", justify_content=\"center\"))\n",
    "\n",
    "observed_widget = ToggleButtons(options=['Fish', 'No Fish'],\n",
    "    description='Observation (m) on the left:', disabled=False, button_style='',\n",
    "    layout=Layout(width='auto', display=\"flex\"),\n",
    "    style={'description_width': 'initial'}\n",
    ")\n",
    "\n",
    "widget_ui = VBox([ps_widget,\n",
    "                  HBox([p_a_s1_widget, p_a_s0_widget]),\n",
    "                  observed_widget])\n",
    "widget_out = interactive_output(plot_prior_likelihood,\n",
    "                                {'ps': ps_widget,\n",
    "                                'p_a_s1': p_a_s1_widget,\n",
    "                                'p_a_s0': p_a_s0_widget,\n",
    "                                'measurement': observed_widget})\n",
    "display(widget_ui, widget_out)\n",
    "\n",
    "# @widgets.interact(\n",
    "#     ps=ps_widget,\n",
    "#     p_a_s1=p_a_s1_widget,\n",
    "#     p_a_s0=p_a_s0_widget,\n",
    "#     m_right=observed_widget\n",
    "# )\n",
    "# def make_prior_likelihood_plot(ps,p_a_s1,p_a_s0,m_right):\n",
    "#     fig = plot_prior_likelihood(ps,p_a_s1,p_a_s0,m_right)\n",
    "#     plt.show(fig)\n",
    "#     plt.close(fig)\n",
    "#     return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "1)  The prior exerts a strong influence over the posterior when it is very informative: when\n",
    "    the probability of the school being on one side or the other. If the prior that the fish are\n",
    "    on the left side is very high (like 0.9), the posterior probability of the state being left is\n",
    "    high regardless of the measurement.\n",
    "\n",
    "2)  When the likelihoods are similar, the information gained from catching a fish or not is less informative.\n",
    "    Intuitively, if you were about as likely to catch a fish regardless of the true location, then catching a fish\n",
    "    doesn't tell you very much! The differences between the likelihoods is a way of thinking about how much information\n",
    "    we can gain. You can try to figure out why, as we've given you all the clues...\n",
    "\n",
    "3)  Similarly to the prior, the likelihood exerts the most influence when it is informative: when catching\n",
    "    a fish tells you a lot of information about which state is likely. For example, if the probability of the\n",
    "    fisherperson catching a fish if he is fishing on the right side and the school is on the left is 0\n",
    "    (p fish | s = left) = 0 and the probability of catching a fish if the school is on the right is 1, the\n",
    "     prior does not affect the posterior at all. The measurement tells you the hidden state completely.\n",
    "\"\"\""
   ]
  },
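  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "The first answer above can also be checked numerically. The sketch below is a minimal illustration, not part of the tutorial's helper code: `posterior_left` is a hypothetical function we introduce here, and the likelihood values are the demo defaults (0.5 and 0.1).\n",
    "\n",
    "```python\n",
    "def posterior_left(p_left, p_fish_left, p_fish_right, caught_fish):\n",
    "  # Hypothetical helper: returns p(s = left | m) for a binary state\n",
    "  # Likelihood of the observed measurement under each state\n",
    "  like_left = p_fish_left if caught_fish else 1 - p_fish_left\n",
    "  like_right = p_fish_right if caught_fish else 1 - p_fish_right\n",
    "\n",
    "  # Unnormalized posterior (likelihood times prior) for each state\n",
    "  joint_left = like_left * p_left\n",
    "  joint_right = like_right * (1 - p_left)\n",
    "\n",
    "  # Normalize by p(m)\n",
    "  return joint_left / (joint_left + joint_right)\n",
    "\n",
    "\n",
    "# Sweep the prior while keeping the likelihood fixed\n",
    "for p_left in [0.1, 0.5, 0.9]:\n",
    "  fish = posterior_left(p_left, 0.5, 0.1, caught_fish=True)\n",
    "  no_fish = posterior_left(p_left, 0.5, 0.1, caught_fish=False)\n",
    "  print(f'prior {p_left:.1f}: fish -> {fish:.2f}, no fish -> {no_fish:.2f}')\n",
    "```\n",
    "\n",
    "With a prior of 0.9, the posterior for s = left stays high whether or not a fish was caught. Setting both likelihood arguments to the same value returns the prior unchanged, which is the uninformative case from the second answer."
   ]
  },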
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Section 6: Making Bayesian fishing decisions\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "\n",
    "We will explore how to consider the expected utility of an action based on our belief (the posterior distribution) about where we think the fish are. Now we have all the components of a Bayesian decision: our prior information, the likelihood given a measurement, the posterior distribution (belief) and our utility (the gains and losses). This allows us to consider the relationship between the true value of the hidden state, $s$, and what we *expect* to get if we take action, $a$, based on our belief!\n",
    "\n",
    "Let's use the following widget to think about the relationship between these probability distributions and utility function."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Interactive Demo! 6: What is more important, the probabilities or the utilities?\n",
    "\n",
    "We are now going to put everything we've learned together to gain some intuitions for how each of the elements that goes into a Bayesian decision comes together. Remember, the common assumption in neuroscience, psychology, economics, ecology, etc. is that we (humans and animals) are tying to maximize our expected utility. There is a lot going on in this demo as it brings everything in this tutorial together in one place - please spend time making sure you understand the controls and the plots, especially how everything relates together.\n",
    "\n",
    "1. Can you find a situation where the expected utility is the same for both actions?\n",
    "2. What is more important for determining the expected utility: the prior or a new measurement (the likelihood)?\n",
    "3. Why is this a normative model?\n",
    "4. Can you think of ways in which this model would need to be extended to describe human or animal behavior?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to enable the widget\n",
    "# style = {'description_width': 'initial'}\n",
    "\n",
    "ps_widget = widgets.FloatSlider(0.3, description='p(s = left)',\n",
    "                                min=0.01, max=0.99, step=0.01, layout=Layout(width='300px'))\n",
    "p_a_s1_widget = widgets.FloatSlider(0.5, description='p(fish on left | state = left)',\n",
    "                                    min=0.01, max=0.99, step=0.01, style=style, layout=Layout(width='370px'))\n",
    "p_a_s0_widget = widgets.FloatSlider(0.1, description='p(fish on left | state = right)',\n",
    "                                    min=0.01, max=0.99, step=0.01, style=style, layout=Layout(width='370px'))\n",
    "\n",
    "observed_widget = ToggleButtons(options=['Fish', 'No Fish'],\n",
    "    description='Observation (m) on the left:', disabled=False, button_style='',\n",
    "    layout=Layout(width='auto', display=\"flex\"),\n",
    "    style={'description_width': 'initial'}\n",
    ")\n",
    "\n",
    "widget_ui = VBox([ps_widget,\n",
    "                  HBox([p_a_s1_widget, p_a_s0_widget]),\n",
    "                  observed_widget])\n",
    "\n",
    "widget_out = interactive_output(plot_prior_likelihood_utility,\n",
    "                                {'ps': ps_widget,\n",
    "                                'p_a_s1': p_a_s1_widget,\n",
    "                                'p_a_s0': p_a_s0_widget,\n",
    "                                'measurement': observed_widget})\n",
    "display(widget_ui, widget_out)\n",
    "\n",
    "# @widgets.interact(\n",
    "#     ps=ps_widget,\n",
    "#     p_a_s1=p_a_s1_widget,\n",
    "#     p_a_s0=p_a_s0_widget,\n",
    "#     m_right=observed_widget\n",
    "# )\n",
    "# def make_prior_likelihood_utility_plot(ps, p_a_s1, p_a_s0,m_right):\n",
    "#     fig = plot_prior_likelihood_utility(ps, p_a_s1, p_a_s0,m_right)\n",
    "#     plt.show(fig)\n",
    "#     plt.close(fig)\n",
    "#     return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "1. There are actually many (infinite) combinations that can produce the same\n",
    "   expected utility for both actions: but the posterior probabilities will always\n",
    "   have to balance out the differences in the utility function. So, what is\n",
    "   important is that for a given utility function, there will be some 'point\n",
    "   of indifference'\n",
    "\n",
    "2. What matters is the relative information: if the prior is close to 50/50,\n",
    "   then the likelihood has more infuence, if the likelihood is 50/50 given a\n",
    "   measurement (the measurement is uninformative), the prior is more important.\n",
    "   But the critical insite from Bayes Rule and the Bayesian approach is that what\n",
    "   matters is the relative information you gain from a measurement, and that\n",
    "   you can use all of this information for your decision.\n",
    "\n",
    "3. The model gives us a very precise way to think about how we *should* combine\n",
    "   information and how we *should* act, GIVEN some assumption about our goals.\n",
    "   In this case, if we assume we are trying to maximize expected utility--we can\n",
    "   state what an animal or subject should do.\n",
    "\n",
    "4. There are lots of possible extensions. Humans may not always try to maximize\n",
    "   utility; humans and animals might not be able to calculate or represent probabiltiy\n",
    "   distributions exactly; The utility function might be more complicated; etc.\n",
    "\"\"\""
   ]
  },
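  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "To make the expected-utility calculation concrete, here is a minimal sketch. The utility values are hypothetical numbers chosen only for illustration and are not necessarily the ones used by the widget. The posterior is the 'no fish' column computed in Coding Exercise 5.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Posterior over states after observing 'no fish':\n",
    "# [p(s = left | m), p(s = right | m)]\n",
    "posterior_no_fish = np.array([0.15, 0.63]) / 0.78\n",
    "\n",
    "# HYPOTHETICAL utilities (assumed for this example): rows are states\n",
    "# (left, right), columns are actions (fish left, fish right)\n",
    "utility_example = np.array([[2.0, -3.0],\n",
    "                            [-2.0, 1.0]])\n",
    "\n",
    "# Expected utility of each action: sum over states of p(s | m) * U(s, a)\n",
    "expected_utility = posterior_no_fish @ utility_example\n",
    "\n",
    "# Choose the action with the highest expected utility\n",
    "actions = ['fish left', 'fish right']\n",
    "best_action = actions[int(np.argmax(expected_utility))]\n",
    "print(expected_utility, best_action)\n",
    "```\n",
    "\n",
    "Changing either the posterior or the utility matrix can flip the preferred action, which is the trade-off the demo above is meant to illustrate."
   ]
  },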
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Summary\n",
    "\n",
    "In this tutorial, you learned about combining prior information with new measurements to update your knowledge using Bayes Rule, in the context of a fishing problem. \n",
    "\n",
    "Specifically, we covered:\n",
    "\n",
    "* That the likelihood is the probability of the measurement given some hidden state\n",
    "\n",
    "* That how the prior and likelihood interact to create the posterior, the probability of the hidden state given a measurement, depends on how they covary\n",
    "\n",
    "* That utility is the gain from each action and state pair, and the expected utility for an action is the sum of the utility for all state pairs, weighted by the probability of that state happening. You can then choose the action with highest expected utility.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Bonus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Bonus Section 1: Correlation Formula\n",
    "To understand the way we calculate the correlation, we need to review the definition of covariance and correlation.\n",
    "\n",
    "Covariance:\n",
    "\n",
    "$$\n",
    "cov(X,Y) = \\sigma_{XY} = E[(X - \\mu_{x})(Y - \\mu_{y})] = E[X]E[Y] - \\mu_{x}\\mu_{y}\n",
    "$$\n",
    "\n",
    "Correlation:\n",
    "\n",
    "$$\n",
    "\\rho_{XY} = \\frac{cov(Y,Y)}{\\sqrt{V(X)V(Y)}} = \\frac{\\sigma_{XY}}{\\sigma_{X}\\sigma_{Y}}\n",
    "$$"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "jWmfLbhzBpfz"
   ],
   "include_colab_link": true,
   "name": "W7_Tutorial1",
   "provenance": [],
   "toc_visible": true
  },
  "interpreter": {
   "hash": "9516f62da91337f10c2adbe814d9c63a4b08f8271333386358218606edb781e3"
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3.7.11 64-bit ('pw3': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
